﻿/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "trans_tensor.h"

#include <list>
#include <cmath>
#include <string>
#include <vector>
#include "securec.h"

#include "infra/base/assertion.h"
#include "infra/math/fp16_t.h"

#include "framework/common/types.h"
#include "framework/infra/log/log.h"
#include "framework/graph/infershape/tensor_format_index_define.h"

#include "common/math/math_util.h"
#include "base/common/graph_tensor_util/cast_check_utils.h"

#if defined(ARM_NEON_32)
#include <arm_neon.h>
#endif

using namespace std;
using namespace ge;
/*
 * @ingroup dnn
 * @brief check whether expr is true, if not print log then return errorcode
 */
#ifndef CHECK
#define CHECK(expr, errorCode, errInfo, ...) \
    do { \
        if (!(expr)) { \
            /* log the caller-supplied message, as the contract above promises */ \
            FMK_LOGE(errInfo, ##__VA_ARGS__); \
            return errorCode; \
        } \
    } while (0)
#endif

/*
 * @ingroup dnn
 * @brief check whether expr is true; if not, return errorCode silently (no logging)
 */
#ifndef CHECK_ONLY_RET
#define CHECK_ONLY_RET(expr, errorCode) \
    do { \
        if (!(expr)) { \
            return errorCode; \
        } \
    } while (0)
#endif

/*
 * @ingroup dnn
 * @brief check whether obj is null; if null, return retValue (note: no log is emitted)
 */
#ifndef CHECK_NULL_WITH_RET
#define CHECK_NULL_WITH_RET(obj, retValue) \
    do { \
        if (obj == nullptr) { \
            return retValue; \
        } \
    } while (0)
#endif

/* integer ceiling division: smallest k such that k * n >= N (expects N >= 0, n > 0) */
#ifndef CEIL
#define CEIL(N, n) (((N) + (n) - 1) / (n))
#endif

#ifdef ARM_NEON
namespace {
/*
 * @brief narrow-convert 8 contiguous fp32 values to 8 fp16 values using AArch64
 *        NEON fcvtn/fcvtn2 (reads 32 bytes at inPtr, writes 16 bytes at outPtr)
 * @param [in] inPtr   address of 8 float32 source values
 * @param [in] outPtr  address receiving 8 float16 results
 * NOTE(review): the clobber list has no "memory" entry; correctness relies on the
 * caller not needing the stores ordered against surrounding accesses — confirm.
 */
inline void TransTensorFp32ToFp16_C8_neon(uint64_t inPtr, uint64_t outPtr)
{
    __asm__ __volatile("mov x2, %[inPtr]\n"
                       "mov x3, %[outPtr]\n"

                       "ld1 {v9.4s, v10.4s}, [x2], #32\n"
                       "fcvtn v11.4h, v9.4s\n"
                       "fcvtn2 v11.8h, v10.4s\n"
                       "st1 {v11.8h}, [x3], #16\n"
                       : [outPtr] "+r"(outPtr)
                       : [inPtr] "r"(inPtr)
                       : "x2", "x3", "v9", "v10", "v11", "cc");
}
}
#endif

namespace hiai {

/*
 * @ingroup fmk
 * @brief mode of data type transform
 */
typedef enum tagDataTypeTransMode {
    CC_DATATYPE_TRANS_FLOAT_NO_TRANS = 0, /* *< origin data is float, no trans */
    CC_DATATYPE_TRANS_FP16_NO_TRANS, /* *< origin data is fp16, no trans */
    CC_DATATYPE_TRANS_INT8_NO_TRANS, /* *< origin data is int8, no trans */
    CC_DATATYPE_TRANS_FLOAT_TO_FP16, /* *< data type float trans to fp16 */
    CC_DATATYPE_TRANS_FP16_TO_FLOAT, /* *< data type fp16 trans to float */
    CC_DATATYPE_TRANS_FLOAT_TO_INT8, /* *< data type float trans to int8 */
    CC_DATATYPE_TRANS_INT8_TO_FLOAT, /* *< data type int8 trans to float */
    CC_DATATYPE_TRANS_UINT8_TO_FLOAT, /* *< data type uint8 trans to float */
    CC_DATATYPE_TRANS_UINT8_NO_TRANS, /* *< origin data is uint8, no trans */
    CC_DATATYPE_TRANS_INT32_NO_TRANS, /* *< origin data is int32, no trans */
    CC_DATATYPE_TRANS_INT64_NO_TRANS, /* *< origin data is int64, no trans */
    CC_DATATYPE_TRANS_MODE_RESERVED
} DataTypeTransMode_t;

/*
 * @ingroup dnn
 * @brief check whether uint64 addition can result in overflow
 * @param [in] a  augend
 * @param [in] b  addend
 * @return uint32_t SUCCESS when a + b does not overflow, FAILED otherwise
 */
inline uint32_t Uint64_addCheck(uint64_t a, uint64_t b)
{
    if (CheckUint64AddOverflow(a, b) != SUCCESS) {
        return FAILED;
    }
    return SUCCESS;
}

/*
 * @ingroup dnn
 * @brief max element number of tensor
 */
const int MAX_TENSOR_ELEMENT_COUNT = 2000000000;

/*
 * @ingroup dnn
 * @brief align size which input and output data actually occupy
 */
const int DATA_MEMORY_ALIGN_SIZE = 32;

/*
 * @ingroup dnn
 * @brief dimcnt of 4d tensor = 4
 */
const int CC_4D_FORMAT_DIMCNT = 4;

/*
 * @ingroup dnn
 * @brief max number of real (logical) dimensions when using the NC1HWC0 format
 */
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-variable"
const int CC_REALDIM_MAX = 4;
#pragma GCC diagnostic pop

/*
 * @ingroup fmk
 * @brief C0 (innermost channel tile) size used for 8-bit data types in NC1HWC0 layout
 */
const int CC_INT8_C0_SIZE = 32;

/*
 * @ingroup dnn
 * @brief check whether int32 multiplication of input array can result in overflow
 * @param [in] array the array to execute multiplication
 * @param [in] num the number of array elements
 * @param [in|out] result point to the result of multiplication
 * @return uint32_t SUCCESS / FAILED (null pointer, empty array, or overflow)
 */
uint32_t CheckInt32ArrayMulOverflow(int32_t array[], int32_t num, int32_t* result)
{
    // array[0] is read unconditionally below, so reject null/empty inputs here
    if ((result == nullptr) || (array == nullptr) || (num <= 0)) {
        return FAILED;
    }

    int32_t product = array[0];
    for (int32_t i = 1; i < num; i++) {
        // validate each partial product before multiplying
        if (CheckInt32MulOverflow(product, array[i]) != SUCCESS) {
            return FAILED;
        }
        product *= array[i];
    }
    *result = product;
    return SUCCESS;
}

// forward declarations: both accessors are defined later in this file
static Status GetTensorNdDescriptor(const ccTensor_t& tensorDesc, uint32_t arrLength, DataType_t& dataType,
    int32_t& dimCnt, int32_t dimA[], int32_t strideA[]);

static Status GetTensor4dDescriptor(const ccTensor_t& tensorDesc, DataType_t& dataType, int32_t& n, int32_t& c,
    int32_t& h, int32_t& w);

/*
 * @ingroup dnn
 * @brief check that a 5D (NCDHW/NDHWC) tensor's element count neither overflows
 *        int32 nor exceeds MAX_TENSOR_ELEMENT_COUNT
 * @param [in] tensorDesc  descriptor to validate; format must be NCDHW or NDHWC
 * @return Status SUCCESS / FAILED
 */
static Status Check5DTensorOverFlow(const ccTensor_t& tensorDesc)
{
    // guard clause: only the two 5D formats are meaningful here
    if ((tensorDesc.format != FORMAT_NCDHW) && (tensorDesc.format != FORMAT_NDHWC)) {
        FMK_LOGE("unknown format."); // fixed typo "unkown"
        return FAILED;
    }
    int32_t productNum = 0;
    int32_t dims[5] = {
        tensorDesc.dim[0], tensorDesc.dim[1], tensorDesc.dim[2], tensorDesc.dim[3], tensorDesc.dim[4]};
    // overflow-checked product of all five dims
    if (CheckInt32ArrayMulOverflow(dims, 5, &productNum) != SUCCESS) {
        FMK_LOGE("5D Tensor dims multiplication can result in overflow!"); // fixed typo "Tesnor"
        return FAILED;
    }
    if (productNum > MAX_TENSOR_ELEMENT_COUNT) {
        FMK_LOGE("5D Tensor element count is too large!");
        return FAILED;
    }
    return SUCCESS;
}

/*
 * @ingroup dnn
 * @brief validate that the element count described by tensorDesc neither overflows
 *        int32 nor exceeds MAX_TENSOR_ELEMENT_COUNT (ND / NCHW / NHWC / NCDHW / NDHWC)
 * @param [in] tensorDesc  descriptor to validate
 * @return Status SUCCESS / FAILED
 */
static Status CheckTensorOverFlow(const ccTensor_t& tensorDesc)
{
    DataType_t dataType = CC_DATA_RESERVED;
    int32_t dimCnt = 0;
    int32_t dim[CC_DIM_MAX] = {0, 0, 0, 0, 0, 0, 0, 0};
    int32_t stride[CC_DIM_MAX] = {0, 0, 0, 0, 0, 0, 0, 0};
    int32_t n = -1;
    int32_t c = -1;
    int32_t h = -1;
    int32_t w = -1;

    // extract dims according to the descriptor's format; the 5D formats keep their
    // dims inside tensorDesc and are validated wholesale by Check5DTensorOverFlow
    if (tensorDesc.format == FORMAT_ND) {
        CHECK_ONLY_RET(GetTensorNdDescriptor(tensorDesc, CC_DIM_MAX, dataType, dimCnt, dim, stride) == SUCCESS, FAILED);
    } else if ((tensorDesc.format == FORMAT_NCDHW) || (tensorDesc.format == FORMAT_NDHWC)) {
        FMK_LOGI("now is 5D format %d.", static_cast<int32_t>(tensorDesc.format));
    } else {
        CHECK_ONLY_RET(
            GetTensor4dDescriptor(tensorDesc, dataType, n, c, h, w) == SUCCESS,
            FAILED);
    }
    switch (tensorDesc.format) {
        case FORMAT_ND: {
            int32_t productNum = 0;
            // overflow-checked product of all ND dims (fixed typo "Tesnor")
            CHECK((CheckInt32ArrayMulOverflow(dim, dimCnt, &productNum) == SUCCESS), FAILED,
                "Tensor dims multiplication can result in overflow!");
            CHECK((productNum <= MAX_TENSOR_ELEMENT_COUNT), FAILED, "Tensor element count is too large!");
            break;
        }
        case FORMAT_NCHW:
        case FORMAT_NHWC: {
            int32_t productNum = 0;
            int32_t dims[4] = {n, c, h, w};
            // overflow-checked n*c*h*w (fixed typo "Tesnor")
            CHECK((CheckInt32ArrayMulOverflow(dims, 4, &productNum) == SUCCESS), FAILED,
                "Tensor dims multiplication can result in overflow!");
            CHECK((productNum <= MAX_TENSOR_ELEMENT_COUNT), FAILED, "Tensor element count is too large!");
            break;
        }
        case FORMAT_NCDHW:
        case FORMAT_NDHWC: {
            if (Check5DTensorOverFlow(tensorDesc) != SUCCESS) {
                return FAILED;
            } // removed stray ';' after this block
            break;
        }
        default:
            return FAILED;
    }
    return SUCCESS;
}

/*
 * @ingroup dnn
 * @brief get the size of data type in device
 * @param [in] dataType   data type in device
 * @param [in|out] size   size of data type in bytes; set to 0 when the type is unknown
 * @return Status SUCCESS / FAILED (unsupported data type)
 */
static Status GetDataTypeSize(DataType_t dataType, uint32_t& size)
{
    struct TypeSize_t {
        DataType_t ccType;
        uint8_t size;
    };

    // immutable lookup table: built once instead of on every call; sized to the
    // actual entry count so the loop never scans zero-filled padding entries
    static const TypeSize_t TYPE_SIZES[] = {
        {CC_DATA_FLOAT, sizeof(float)},
        {CC_DATA_HALF_UINT16_PROPOSAL, sizeof(fp16_t)},
        {CC_DATA_HALF, sizeof(fp16_t)},
        {CC_DATA_INT8, sizeof(int8_t)},
        {CC_DATA_INT32, sizeof(int32_t)},
        {CC_DATA_UINT8, sizeof(uint8_t)},
        {CC_DATA_UINT32, sizeof(uint32_t)},
        {CC_DATA_INT16, sizeof(int16_t)},
        {CC_DATA_UINT16, sizeof(uint16_t)},
        {CC_DATA_INT64, sizeof(int64_t)},
        {CC_DATA_UINT64, sizeof(uint64_t)},
        {CC_DATA_DOUBLE, sizeof(double)},
        {CC_DATA_BOOL, sizeof(uint8_t)},
        {CC_DATA_DUAL, sizeof(fp16_t) + sizeof(int8_t)},
        {CC_DATA_DUAL_SUB_UINT8, sizeof(int8_t)},
        {CC_DATA_DUAL_SUB_INT8, sizeof(int8_t)},
        {CC_DATA_QINT8, sizeof(int8_t)},
        {CC_DATA_QUINT8, sizeof(uint8_t)},
        {CC_DATA_QINT16, sizeof(int16_t)},
        {CC_DATA_QUINT16, sizeof(uint16_t)},
        {CC_DATA_QINT32, sizeof(int32_t)},
        {CC_DATA_2BITS, sizeof(int8_t)},
    };

    size = 0;
    for (const auto& entry : TYPE_SIZES) {
        if (entry.ccType == dataType) {
            size = entry.size;
            break;
        }
    }
    // every listed size is non-zero, so size == 0 means "not found"
    return (size == 0) ? FAILED : SUCCESS;
} // removed stray ';' after function body

/*
 * @brief fill NCHW dims into tensorDesc.dim[] and compute the total element count
 * @param [out] tensorDesc  descriptor whose dim[] receives N,C,H,W
 * @param [in]  dims        logical dims indexed by NCHW_DIM_*
 * @param [out] elementCnt  n * c * h * w
 * @return Status SUCCESS / FAILED (overflow or count too large)
 */
static Status SetNCHWTensorDimAndCalcCount(
    ccTensor_t& tensorDesc, std::vector<int32_t>& dims, int32_t& elementCnt)
{
    const int32_t n = dims[NCHW_DIM_N];
    const int32_t c = dims[NCHW_DIM_C];
    const int32_t h = dims[NCHW_DIM_H];
    const int32_t w = dims[NCHW_DIM_W];
    tensorDesc.dim[0] = n;
    tensorDesc.dim[1] = c;
    tensorDesc.dim[2] = h;
    tensorDesc.dim[3] = w;

    int32_t operands[4] = {n, c, h, w};
    int32_t product = 0;
    // overflow-checked n*c*h*w
    if (CheckInt32ArrayMulOverflow(operands, 4, &product) != SUCCESS) {
        return FAILED;
    }
    if (product > MAX_TENSOR_ELEMENT_COUNT) {
        return FAILED;
    }
    elementCnt = product;
    return SUCCESS;
}

/*
 * @brief fill NHWC dims into tensorDesc.dim[] and compute the total element count
 * @param [out] tensorDesc  descriptor whose dim[] receives N,H,W,C
 * @param [in]  dims        logical dims indexed by NCHW_DIM_*
 * @param [out] elementCnt  n * h * w * c
 * @return Status SUCCESS / FAILED (overflow or count too large)
 */
static Status SetNHWCTensorDimAndCalcCount(
    ccTensor_t& tensorDesc, std::vector<int32_t>& dims, int32_t& elementCnt)
{
    const int32_t n = dims[NCHW_DIM_N];
    const int32_t h = dims[NCHW_DIM_H];
    const int32_t w = dims[NCHW_DIM_W];
    const int32_t c = dims[NCHW_DIM_C];
    tensorDesc.dim[0] = n;
    tensorDesc.dim[1] = h;
    tensorDesc.dim[2] = w;
    tensorDesc.dim[3] = c;

    int32_t operands[4] = {n, h, w, c};
    int32_t product = 0;
    // overflow-checked n*h*w*c
    if (CheckInt32ArrayMulOverflow(operands, 4, &product) != SUCCESS) {
        return FAILED;
    }
    if (product > MAX_TENSOR_ELEMENT_COUNT) {
        return FAILED;
    }
    elementCnt = product;
    return SUCCESS;
}

/*
 * @brief fill NCHW dims into tensorDesc.dim[] and compute the padded NC1HWC0
 *        element count: C is tiled into C1 groups of C0 lanes, where C0 is
 *        CC_INT8_C0_SIZE for 8-bit/2-bit types and CC_CUBE_SIZE otherwise
 * @param [out] tensorDesc  descriptor whose dim[] receives N,C,H,W
 * @param [in]  dataType    element type (selects C0)
 * @param [in]  dims        logical dims indexed by NCHW_DIM_*
 * @param [out] elementCnt  n * c1 * h * w * c0
 * @return Status SUCCESS / FAILED (overflow or count too large)
 */
static Status SetNC1HWC0TensorDimAndCalcCount(
    ccTensor_t& tensorDesc, DataType_t dataType, std::vector<int32_t>& dims, int32_t& elementCnt)
{
    const bool byteSizedType =
        (dataType == CC_DATA_BOOL) || (dataType == CC_DATA_INT8) || (dataType == CC_DATA_UINT8) ||
        (dataType == CC_DATA_DUAL_SUB_UINT8) || (dataType == CC_DATA_QINT8) || (dataType == CC_DATA_QUINT8) ||
        (dataType == CC_DATA_DUAL_SUB_INT8) || (dataType == CC_DATA_2BITS);
    const int32_t c0 = byteSizedType ? CC_INT8_C0_SIZE : CC_CUBE_SIZE;
    // number of C0-wide channel tiles needed to cover C
    const int32_t c1 = static_cast<int32_t>(std::ceil(dims[NCHW_DIM_C] * 1.0 / c0));

    tensorDesc.dim[0] = dims[NCHW_DIM_N];
    tensorDesc.dim[1] = dims[NCHW_DIM_C];
    tensorDesc.dim[2] = dims[NCHW_DIM_H];
    tensorDesc.dim[3] = dims[NCHW_DIM_W];

    int32_t operands[5] = {dims[NCHW_DIM_N], c1, dims[NCHW_DIM_H], dims[NCHW_DIM_W], c0};
    int32_t product = 0;
    // overflow-checked n*c1*h*w*c0
    if (CheckInt32ArrayMulOverflow(operands, 5, &product) != SUCCESS) {
        return FAILED;
    }
    if (product > MAX_TENSOR_ELEMENT_COUNT) {
        return FAILED;
    }
    elementCnt = product;
    return SUCCESS;
}

/*
 * @ingroup dnn
 * @brief fill a 4D tensor descriptor: dims, format, data type and total byte size
 * @param [out] tensorDesc descriptor to fill
 * @param [in] format      NCHW / NHWC / NC1HWC0 / C1HWNC0
 * @param [in] dataType    element type
 * @param [in] dims        logical dims indexed by NCHW_DIM_*; all must be positive
 * @return Status SUCCESS / FAILED
 */
HCS_API_EXPORT Status SetTensor4dDescriptor(
    ccTensor_t& tensorDesc, ge::Format format, DataType_t dataType, std::vector<int32_t>& dims)
{
    // guard the NCHW_DIM_* indexing below: an undersized vector was previously UB
    if (dims.size() < static_cast<size_t>(CC_4D_FORMAT_DIMCNT)) {
        FMK_LOGE("dims size %zu is not enough for a 4d tensor!", dims.size());
        return FAILED;
    }
    if ((dims[NCHW_DIM_N] <= 0) || (dims[NCHW_DIM_C] <= 0) || (dims[NCHW_DIM_H] <= 0) || (dims[NCHW_DIM_W] <= 0)) {
        return FAILED;
    }

    tensorDesc.dataType = dataType;
    tensorDesc.dimCnt = CC_4D_FORMAT_DIMCNT;
    tensorDesc.format = format;
    // calc data type size
    uint32_t dataTypeSize = 0;
    Status ret = GetDataTypeSize(dataType, dataTypeSize);
    CHECK((ret == SUCCESS), ret, "GetDataTypeSize failed, ret is %d!", ret);

    tensorDesc.dataSize = dataTypeSize;
    int32_t elementCnt = 1;
    switch (format) {
        case FORMAT_NCHW:
            ret = SetNCHWTensorDimAndCalcCount(tensorDesc, dims, elementCnt);
            CHECK_ONLY_RET((ret == SUCCESS), ret);
            break;
        case FORMAT_NHWC:
            ret = SetNHWCTensorDimAndCalcCount(tensorDesc, dims, elementCnt);
            CHECK_ONLY_RET((ret == SUCCESS), ret);
            break;
        case FORMAT_NC1HWC0:
        case FORMAT_C1HWNC0:
            ret = SetNC1HWC0TensorDimAndCalcCount(tensorDesc, dataType, dims, elementCnt);
            CHECK_ONLY_RET((ret == SUCCESS), ret);
            break;
        default:
            FMK_LOGE("format %d is not supported!", format);
            return FAILED;
    }
    // total byte size = dtype size * element count, overflow-checked
    FMK_UINT32_MULCHECK(tensorDesc.dataSize, static_cast<uint32_t>(elementCnt));
    tensorDesc.dataSize *= static_cast<uint32_t>(elementCnt);
    return SUCCESS;
}

/*
 * @brief read dataType and the four logical dims (n, c, h, w) from a 4D descriptor
 * @param [in] tensorDesc  descriptor; dimCnt must be 4, format NCHW or NHWC
 * @param [out] dataType   element type
 * @param [out] n,c,h,w    logical dims, de-permuted from the storage order
 * @return Status SUCCESS / FAILED (wrong dimCnt or unsupported format)
 */
static Status GetTensor4dDescriptor(const ccTensor_t& tensorDesc, DataType_t& dataType, int32_t& n, int32_t& c,
    int32_t& h, int32_t& w)
{
    if (tensorDesc.dimCnt != CC_4D_FORMAT_DIMCNT) {
        return FAILED;
    }
    dataType = tensorDesc.dataType;
    if (tensorDesc.format == FORMAT_NCHW) {
        n = tensorDesc.dim[0];
        c = tensorDesc.dim[1];
        h = tensorDesc.dim[2];
        w = tensorDesc.dim[3];
        return SUCCESS;
    }
    if (tensorDesc.format == FORMAT_NHWC) {
        n = tensorDesc.dim[0];
        h = tensorDesc.dim[1];
        w = tensorDesc.dim[2];
        c = tensorDesc.dim[3];
        return SUCCESS;
    }
    return FAILED;
}

/* ND case: only 5D (NCDHW / NDHWC) is supported for now. */
HCS_API_EXPORT Status SetNDTensorDescriptor(
    ccTensor_t& tensorDesc, ge::Format format, DataType_t dataType, std::vector<int32_t>& dims)
{
    const bool is5dFormat = (format == FORMAT_NCDHW) || (format == FORMAT_NDHWC);
    if (!is5dFormat || (dims.size() != 5)) {
        FMK_LOGE("para is error. format: %d, dimSize: %zu", static_cast<int32_t>(format), dims.size());
        return FAILED;
    }

    // set ND info; realDimCnt = -1 marks "not applicable" for this path
    tensorDesc.dataType = dataType;
    tensorDesc.dimCnt = static_cast<int32_t>(dims.size());
    tensorDesc.realDimCnt = -1;
    tensorDesc.format = format;
    for (size_t i = 0; i < dims.size(); ++i) {
        tensorDesc.dim[i] = dims[i];
    }

    // byte size of one element of dataType
    uint32_t dataTypeSize = 0;
    const Status ret = GetDataTypeSize(dataType, dataTypeSize);
    if (ret != SUCCESS) {
        FMK_LOGE("GetDataTypeSize failed, ret is %d!", ret);
        return ret;
    }
    tensorDesc.dataSize = dataTypeSize;

    // overflow-checked product of all dims
    int32_t elementCnt = 0;
    if (CheckInt32ArrayMulOverflow(dims.data(), static_cast<int32_t>(dims.size()), &elementCnt) != SUCCESS) {
        FMK_LOGE("Integer multiplication can result in overflow!");
        return FAILED;
    }
    if (elementCnt > MAX_TENSOR_ELEMENT_COUNT) {
        FMK_LOGE("The tensor element count %d is too large!", elementCnt);
        return FAILED;
    }

    // total byte size = dtype size * element count, overflow-checked
    FMK_UINT32_MULCHECK(tensorDesc.dataSize, static_cast<uint32_t>(elementCnt));
    tensorDesc.dataSize *= static_cast<uint32_t>(elementCnt);
    return SUCCESS;
}

/*
 * @brief fill an ND tensor descriptor from an explicit dim array
 * @param [out] tensorDesc  descriptor to fill (format forced to FORMAT_ND)
 * @param [in] dataType     element type
 * @param [in] dimCnt       number of dims, 0..CC_DIM_MAX
 * @param [in] dimA         dim values; each must be positive; may be null only when dimCnt == 0
 * @return Status SUCCESS / FAILED
 */
HCS_API_EXPORT Status SetTensorNdDescriptor(ccTensor_t& tensorDesc, DataType_t dataType, int32_t dimCnt, int32_t dimA[])
{
    if (((dimCnt > 0) && (dimA == nullptr)) || (dimCnt < 0) || (dimCnt > CC_DIM_MAX)) {
        return FAILED;
    }

    tensorDesc.dataType = dataType;
    tensorDesc.dimCnt = dimCnt;
    tensorDesc.format = FORMAT_ND;

    // byte size of one element of dataType
    uint32_t typeSize = 0;
    const Status ret = GetDataTypeSize(dataType, typeSize);
    if (ret != SUCCESS) {
        return ret;
    }
    tensorDesc.dataSize = typeSize;

    // copy dims while accumulating the overflow-checked element count
    int32_t totalCnt = 1;
    for (int32_t idx = 0; idx < dimCnt; ++idx) {
        const int32_t d = dimA[idx];
        if (d <= 0) {
            return FAILED;
        }
        tensorDesc.dim[idx] = d;
        FMK_INT32_MULCHECK(totalCnt, d);
        totalCnt *= d;
    }

    if (totalCnt > MAX_TENSOR_ELEMENT_COUNT) {
        return FAILED;
    }

    // total byte size = dtype size * element count, overflow-checked
    FMK_INT32_MULCHECK(static_cast<int32_t>(tensorDesc.dataSize), static_cast<int32_t>(totalCnt));
    tensorDesc.dataSize *= static_cast<uint32_t>(totalCnt);

    return SUCCESS;
}

/*
 * @brief read dataType, dimCnt and dims out of an ND-format tensor descriptor
 * @param [in] tensorDesc   descriptor to read; format must be FORMAT_ND
 * @param [in] arrLength    capacity of dimA
 * @param [out] dataType    element data type
 * @param [out] dimCnt      number of valid dims
 * @param [out] dimA        receives the first dimCnt dims
 * @param [out] strideA     checked for null but never written here
 *                          NOTE(review): strides are not filled in — confirm no caller reads strideA
 * @return Status SUCCESS / FAILED
 */
static Status GetTensorNdDescriptor(const ccTensor_t& tensorDesc, uint32_t arrLength, DataType_t& dataType,
    int32_t& dimCnt, int32_t dimA[], int32_t strideA[])
{
    if ((dimA == nullptr) || (strideA == nullptr)) {
        return FAILED;
    }

    if (tensorDesc.format != FORMAT_ND) {
        return FAILED;
    }

    dataType = tensorDesc.dataType;
    dimCnt = tensorDesc.dimCnt;

    // never copy more dims than the caller's buffer can hold
    CHECK(static_cast<uint32_t>(dimCnt) <= arrLength, FAILED, "dimCntReq are not allowed to exceed %u.", arrLength);
    for (int32_t i = 0; i < dimCnt; i++) {
        dimA[i] = tensorDesc.dim[i];
    }
    return SUCCESS;
}

/*
 * @brief compute the aligned device memory footprint of a tensor: dataSize padded
 *        by one extra alignment block, rounded up to DATA_MEMORY_ALIGN_SIZE
 * @param [in] tensorDesc  descriptor providing dataSize
 * @param [out] size       aligned byte count
 * @return Status SUCCESS (FMK_* checks return FAILED on uint32 overflow)
 */
HCS_API_EXPORT Status GetTensorMemorySizeInBytes(const ccTensor_t& tensorDesc, uint32_t& size)
{
    const uint32_t padding = 2 * DATA_MEMORY_ALIGN_SIZE - 1;
    FMK_UINT32_ADDCHECK(tensorDesc.dataSize, padding);

    const uint32_t alignedBlocks = (tensorDesc.dataSize + padding) / DATA_MEMORY_ALIGN_SIZE;
    FMK_UINT32_MULCHECK(alignedBlocks, DATA_MEMORY_ALIGN_SIZE);

    size = alignedBlocks * DATA_MEMORY_ALIGN_SIZE;
    return SUCCESS;
}

/*
 * @brief map a single data type to its "no transform" mode (used when source and
 *        destination types are identical)
 * @param [in] xType               common data type of source and destination
 * @param [out] dataTypeTransmode  resulting *_NO_TRANS mode
 * @return Status SUCCESS / FAILED (unsupported type)
 */
static Status GetDataTypeTransModeFunc0(const DataType_t xType, DataTypeTransMode_t& dataTypeTransmode)
{
    switch (xType) {
        case CC_DATA_HALF:
            dataTypeTransmode = CC_DATATYPE_TRANS_FP16_NO_TRANS;
            break;
        case CC_DATA_FLOAT:
            dataTypeTransmode = CC_DATATYPE_TRANS_FLOAT_NO_TRANS;
            break;
        case CC_DATA_INT8:
            dataTypeTransmode = CC_DATATYPE_TRANS_INT8_NO_TRANS;
            break;
        case CC_DATA_INT32:
            dataTypeTransmode = CC_DATATYPE_TRANS_INT32_NO_TRANS;
            break;
        case CC_DATA_INT64:
            dataTypeTransmode = CC_DATATYPE_TRANS_INT64_NO_TRANS;
            break;
        case CC_DATA_UINT8:
        case CC_DATA_BOOL:
        case CC_DATA_QUINT8:
            // bool and quint8 share the uint8 pass-through
            dataTypeTransmode = CC_DATATYPE_TRANS_UINT8_NO_TRANS;
            break;
        default:
            return FAILED;
    }
    return SUCCESS;
}

/*
 * @brief select the transform mode for a (source, destination) data type pair
 * @param [in] xType               source element type
 * @param [in] yType               destination element type
 * @param [out] dataTypeTransmode  selected mode
 * @return Status SUCCESS / FAILED (unsupported pair)
 */
static Status GetDataTypeTransMode(
    const DataType_t xType, const DataType_t yType, DataTypeTransMode_t& dataTypeTransmode)
{
    // identical types: delegate to the NO_TRANS mapping
    if (xType == yType) {
        return GetDataTypeTransModeFunc0(xType, dataTypeTransmode);
    }
    if ((xType == CC_DATA_FLOAT) && (yType == CC_DATA_HALF)) {
        dataTypeTransmode = CC_DATATYPE_TRANS_FLOAT_TO_FP16;
        return SUCCESS;
    }
    if ((xType == CC_DATA_HALF) && (yType == CC_DATA_FLOAT)) {
        dataTypeTransmode = CC_DATATYPE_TRANS_FP16_TO_FLOAT;
        return SUCCESS;
    }
    // remaining supported conversions all target float
    if (yType == CC_DATA_FLOAT) {
        if ((xType == CC_DATA_UINT8) || (xType == CC_DATA_QUINT8)) {
            dataTypeTransmode = CC_DATATYPE_TRANS_UINT8_TO_FLOAT;
            return SUCCESS;
        }
        if (xType == CC_DATA_INT8) {
            dataTypeTransmode = CC_DATATYPE_TRANS_INT8_TO_FLOAT;
            return SUCCESS;
        }
    }
    return FAILED;
}

/*
 * @brief copy element srcIdx of buffer x into element dstIdx of buffer y, converting
 *        according to the enclosing function's local `dataTypeTransmode`.
 *        Requires locals `dataTypeTransmode` (DataTypeTransMode_t) and `fp` (fp16_t
 *        scratch) in the enclosing scope; expands to `return FAILED;` (exiting the
 *        enclosing function) for unsupported modes.
 */
#define CopyToDstByDatatypetransmode(x, y, srcIdx, dstIdx) \
    do { \
        switch (dataTypeTransmode) { \
            case CC_DATATYPE_TRANS_FP16_TO_FLOAT: \
                fp = static_cast<const fp16_t*>(x)[srcIdx]; \
                static_cast<float*>(y)[dstIdx] = fp; \
                break; \
            case CC_DATATYPE_TRANS_FLOAT_TO_FP16 : \
                fp = static_cast<const float*>(x)[srcIdx]; \
                static_cast<uint16_t*>(y)[dstIdx] = fp.val; \
                break; \
            case CC_DATATYPE_TRANS_FP16_NO_TRANS: \
                static_cast<fp16_t*>(y)[dstIdx] = static_cast<const fp16_t*>(x)[srcIdx]; \
                break; \
            case CC_DATATYPE_TRANS_UINT8_TO_FLOAT: \
                static_cast<float*>(y)[dstIdx] = static_cast<float>(static_cast<const uint8_t*>(x)[srcIdx]); \
                break; \
            case CC_DATATYPE_TRANS_INT8_TO_FLOAT: \
                static_cast<float*>(y)[dstIdx] = static_cast<float>(static_cast<const int8_t*>(x)[srcIdx]); \
                break; \
            case CC_DATATYPE_TRANS_UINT8_NO_TRANS: \
                static_cast<uint8_t*>(y)[dstIdx] = static_cast<const uint8_t*>(x)[srcIdx]; \
                break; \
            case CC_DATATYPE_TRANS_INT8_NO_TRANS: \
                static_cast<int8_t*>(y)[dstIdx] = static_cast<const int8_t*>(x)[srcIdx]; \
                break; \
            case CC_DATATYPE_TRANS_FLOAT_NO_TRANS: \
                static_cast<float*>(y)[dstIdx] = static_cast<const float*>(x)[srcIdx]; \
                break; \
            case CC_DATATYPE_TRANS_INT32_NO_TRANS: \
                static_cast<int32_t*>(y)[dstIdx] = static_cast<const int32_t*>(x)[srcIdx]; \
                break; \
            case CC_DATATYPE_TRANS_INT64_NO_TRANS: \
                static_cast<int64_t*>(y)[dstIdx] = static_cast<const int64_t*>(x)[srcIdx]; \
                break; \
            default: \
                return FAILED; \
        } \
    } while (0)

/*
 * @brief permute an NHWC tensor into NCHW layout, converting the element type as
 *        selected by the (xDesc.dataType, yDesc.dataType) pair
 * @param [in] xDesc  source descriptor (NHWC)
 * @param [in] x      source buffer
 * @param [in] yDesc  destination descriptor (NCHW); its dims drive the loops
 * @param [out] y     destination buffer
 * @return Status SUCCESS / FAILED (unsupported type pair)
 */
static Status TransTensorNHWCToNCHW(const ccTensor_t& xDesc, const void* x, const ccTensor_t& yDesc, void* y)
{
    uint32_t n = static_cast<uint32_t>(yDesc.dim[0]);
    uint32_t c = static_cast<uint32_t>(yDesc.dim[1]);
    uint32_t h = static_cast<uint32_t>(yDesc.dim[2]);
    uint32_t w = static_cast<uint32_t>(yDesc.dim[3]);
    uint32_t dstIdx = 0;
    uint32_t idx = 0;
    fp16_t fp; // scratch required by CopyToDstByDatatypetransmode for fp16 conversions

    DataTypeTransMode_t dataTypeTransmode = CC_DATATYPE_TRANS_FLOAT_NO_TRANS;
    if (GetDataTypeTransMode(xDesc.dataType, yDesc.dataType, dataTypeTransmode) != SUCCESS) {
        return FAILED;
    }
    // destination is written sequentially in NCHW order; idx is the NHWC
    // linearization of the same (n, c, h, w) coordinate in the source
    for (uint32_t nIdx = 0; nIdx < n; nIdx++) {
        for (uint32_t cIdx = 0; cIdx < c; cIdx++) {
            for (uint32_t hIdx = 0; hIdx < h; hIdx++) {
                for (uint32_t wIdx = 0; wIdx < w; wIdx++) {
                    idx = cIdx + wIdx * c + hIdx * w * c + nIdx * h * w * c; // (n h w c)
                    // may `return FAILED` for unsupported modes
                    CopyToDstByDatatypetransmode(x, y, idx, dstIdx);
                    dstIdx++;
                }
            }
        }
    }
    return SUCCESS;
}

/*
 * @brief convert a float32 tensor to float16 element by element
 * @param [in] xDesc  source descriptor; dataSize determines the element count
 * @param [in] x      source buffer of float32 values
 * @param [in] yDesc  destination descriptor; dataSize must hold dataCnt fp16 values
 * @param [out] y     destination buffer of float16 values
 * @return Status SUCCESS / FAILED (null buffer or undersized output)
 */
Status TransTensorFloatToHALF(const ccTensor_t& xDesc, const void* x, const ccTensor_t& yDesc, void* y)
{
    CHECK_NULL_WITH_RET(x, FAILED);
    CHECK_NULL_WITH_RET(y, FAILED);

    // element count derived from the input byte size
    uint32_t dataCnt = xDesc.dataSize / sizeof(float);
    if (yDesc.dataSize < dataCnt * sizeof(fp16_t)) {
        FMK_LOGE("outputDataSize:%u not enough!", yDesc.dataSize);
        return FAILED;
    }

#if defined(ARM_NEON)
    // AArch64 path: the FIRST (dataCnt % 8) elements are handled by the scalar tail
    // loop below; the remaining multiple-of-8 elements start at byte offset
    // lastTime * 4 (input) / lastTime * 2 (output) and are converted 8 at a time.
    uint32_t loopTime = dataCnt / 8;
    uint32_t lastTime = dataCnt % 8;
    uint64_t inPtr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(x)) + lastTime * 4;
    uint64_t outPtr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(y)) + lastTime * 2;
    for (uint32_t i = 0; i < loopTime; i++) {
        TransTensorFp32ToFp16_C8_neon(inPtr, outPtr);
        inPtr = inPtr + 32;   // 8 floats consumed
        outPtr = outPtr + 16; // 8 halves produced
    }
    // scalar conversion of the leading remainder elements
    for (uint32_t i = 0; i < lastTime; i++) {
        fp16_t fp16Data;
        fp16Data = static_cast<const float*>(x)[i];
        static_cast<uint16_t*>(y)[i] = fp16Data.val;
    }

#elif defined(ARM_NEON_32)

    // ARM 32-bit path: same leading-remainder layout, 4 elements per vector step
    uint32_t loopTime = dataCnt / 4;
    uint32_t lastTime = dataCnt % 4;
    uint64_t inPtr = (uint64_t)((uintptr_t)x) + lastTime * 4;
    uint64_t outPtr = (uint64_t)((uintptr_t)y) + lastTime * 2;

    float16x4_t tmpfp16;
    float32x4_t tmpfp32;
    for (uint32_t i = 0; i < loopTime; i++) {
        tmpfp32 = vld1q_f32(reinterpret_cast<float32_t*>(inPtr));
        tmpfp16 = vcvt_f16_f32(tmpfp32);
        vst1_f16(reinterpret_cast<__fp16*>(outPtr), tmpfp16);
        inPtr = inPtr + 16;  // 4 floats consumed
        outPtr = outPtr + 8; // 4 halves produced
    }
    // scalar conversion of the leading remainder elements
    for (uint32_t i = 0; i < lastTime; i++) {
        fp16_t fp16Data;
        fp16Data = static_cast<const float*>(x)[i];
        static_cast<uint16_t*>(y)[i] = fp16Data.val;
    }
#else
    // portable scalar path
    for (uint32_t i = 0; i < dataCnt; i++) {
        fp16_t fp16Data;
        fp16Data = static_cast<const float*>(x)[i];
        static_cast<uint16_t*>(y)[i] = fp16Data.val;
    }
#endif
    return SUCCESS;
}

/*
 * @brief same-type float32 copy between tensor buffers
 * @param [in] xDesc  source descriptor; dataSize determines the element count
 * @param [in] x      source buffer
 * @param [in] yDesc  destination descriptor; dataSize must hold all elements
 * @param [out] y     destination buffer
 * @return Status SUCCESS / FAILED (undersized output or copy failure)
 */
static Status TransTensorFloatToFloat(const ccTensor_t& xDesc, const float* x, const ccTensor_t& yDesc, float* y)
{
    const uint32_t elementCnt = xDesc.dataSize / sizeof(float);
    const uint32_t copyBytes = elementCnt * sizeof(float);
    if (yDesc.dataSize < copyBytes) {
        return FAILED;
    }
    // bounds-checked copy (securec); non-zero return means the copy failed
    CHECK_ONLY_RET(memcpy_s(y, yDesc.dataSize, x, copyBytes) == 0, FAILED);
    return SUCCESS;
}

/*
 * @brief same-type int32 copy between tensor buffers
 * @param [in] xDesc  source descriptor; dataSize determines the element count
 * @param [in] x      source buffer
 * @param [in] yDesc  destination descriptor; dataSize must hold all elements
 * @param [out] y     destination buffer
 * @return Status SUCCESS / FAILED (undersized output)
 */
static Status TransTensorINT32ToINT32(const ccTensor_t& xDesc, const int32_t* x, const ccTensor_t& yDesc, int32_t* y)
{
    const uint32_t elementCnt = xDesc.dataSize / sizeof(int32_t);
    if (yDesc.dataSize < elementCnt * sizeof(int32_t)) {
        return FAILED;
    }
    // straight element-wise copy
    for (uint32_t idx = 0; idx < elementCnt; ++idx) {
        y[idx] = x[idx];
    }
    return SUCCESS;
}

/*
 * @brief narrowing int64 -> int32 copy between tensor buffers
 * @param [in] xDesc  source descriptor; dataSize determines the element count
 * @param [in] x      source buffer of int64 values
 * @param [in] yDesc  destination descriptor; dataSize must hold all int32 results
 * @param [out] y     destination buffer
 * @return Status SUCCESS / FAILED (undersized output)
 * NOTE(review): values outside int32 range are silently narrowed — upstream is
 * presumed to keep dims/values in range; confirm against callers.
 */
static Status TransTensorINT64ToINT32(const ccTensor_t& xDesc, const int64_t* x, const ccTensor_t& yDesc, int32_t* y)
{
    const uint32_t elementCnt = xDesc.dataSize / sizeof(int64_t);
    if (yDesc.dataSize < elementCnt * sizeof(int32_t)) {
        return FAILED;
    }
    for (uint32_t idx = 0; idx < elementCnt; ++idx) {
        y[idx] = static_cast<int32_t>(x[idx]);
    }
    return SUCCESS;
}

/*
 * @brief int64 -> float32 conversion copy between tensor buffers
 * @param [in] xDesc  source descriptor; dataSize determines the element count
 * @param [in] x      source buffer of int64 values
 * @param [in] yDesc  destination descriptor; dataSize must hold all float results
 * @param [out] y     destination buffer
 * @return Status SUCCESS / FAILED (undersized output)
 */
static Status TransTensorINT64ToFloat(const ccTensor_t& xDesc, const int64_t* x, const ccTensor_t& yDesc, float* y)
{
    const uint32_t elementCnt = xDesc.dataSize / sizeof(int64_t);
    if (yDesc.dataSize < elementCnt * sizeof(float)) {
        return FAILED;
    }
    for (uint32_t idx = 0; idx < elementCnt; ++idx) {
        y[idx] = static_cast<float>(x[idx]);
    }
    return SUCCESS;
}

/*
 * @brief convert a float16 tensor to uint8 via fp16_t::toUInt8
 * @param [in] xDesc  source descriptor; dataSize determines the element count
 * @param [in] x      source buffer of fp16 values (read as raw uint16 words)
 * @param [in] yDesc  destination descriptor; dataSize must hold all uint8 results
 * @param [out] y     destination buffer
 * @return Status SUCCESS / FAILED (undersized output)
 */
static Status TransTensorHALFToUINT8(const ccTensor_t& xDesc, const void* x, const ccTensor_t& yDesc, void* y)
{
    uint32_t dataCnt = xDesc.dataSize / sizeof(fp16_t);
    if (yDesc.dataSize < dataCnt * sizeof(uint8_t)) {
        FMK_LOGE("outputDataSize:%u not enough!", yDesc.dataSize);
        return FAILED;
    }

    for (uint32_t i = 0; i < dataCnt; i++) {
        // reinterpret the raw half bits, then let fp16_t perform the conversion
        fp16_t fp16Data = (static_cast<const uint16_t*>(x))[i];
        static_cast<uint8_t*>(y)[i] = fp16Data.toUInt8();
    }
    return SUCCESS;
}

/*
 * @brief dispatch the tensor copy/convert matching the (xDesc, yDesc) format and
 *        dataType pair. Every path expands to a `return ...;`, so the enclosing
 *        function terminates inside this macro.
 *        NOTE(review): the ySizeInBytes parameter is accepted but unused here.
 */
#define DoTransTensor(xDesc, x, yDesc, y, ySizeInBytes) \
    do { \
        if (((xDesc).format == (yDesc).format) || \
            (((xDesc).format == FORMAT_ND) && ((yDesc).format == FORMAT_NCHW)) || \
            (((xDesc).format == FORMAT_NCHW) && ((yDesc).format == FORMAT_ND)) || \
            (((xDesc).format == FORMAT_NHWC) && ((yDesc).format == FORMAT_ND))) { \
            if (((xDesc).dataType == CC_DATA_FLOAT) && ((yDesc).dataType == CC_DATA_HALF)) { \
                return TransTensorFloatToHALF(xDesc, x, yDesc, y); \
            } else if (((xDesc).dataType == CC_DATA_HALF) && ((yDesc).dataType == CC_DATA_FLOAT)) { \
                return TransTensorHALFToFloat(xDesc, x, yDesc, y); \
            } else if (((xDesc).dataType == CC_DATA_FLOAT) && ((yDesc).dataType == CC_DATA_FLOAT)) { \
                return TransTensorFloatToFloat(xDesc, static_cast<const float*>(x), yDesc, static_cast<float*>(y)); \
            } else if (((xDesc).dataType == CC_DATA_INT32) && ((yDesc).dataType == CC_DATA_INT32)) { \
                return TransTensorINT32ToINT32( \
                    (xDesc), (static_cast<const int32_t*>(x)), (yDesc), (static_cast<int32_t*>(y))); \
            } else if (((xDesc).dataType == CC_DATA_INT64) && ((yDesc).dataType == CC_DATA_INT32)) { \
                return TransTensorINT64ToINT32( \
                    (xDesc), (static_cast<const int64_t*>(x)), (yDesc), (static_cast<int32_t*>(y))); \
            } else if (((xDesc).dataType == CC_DATA_INT64) && ((yDesc).dataType == CC_DATA_FLOAT)) { \
                return TransTensorINT64ToFloat( \
                    (xDesc), (static_cast<const int64_t*>(x)), (yDesc), (static_cast<float*>(y))); \
            } else if (((xDesc).dataType == CC_DATA_HALF) && ((yDesc).dataType == CC_DATA_UINT8)) { \
                return TransTensorHALFToUINT8(xDesc, x, yDesc, y); \
            } else { \
                return FAILED; \
            } \
        } else if (((xDesc).format == FORMAT_NHWC) && ((yDesc).format == FORMAT_NCHW)) { \
            return TransTensorNHWCToNCHW(xDesc, x, yDesc, y); \
        } else { \
            return FAILED; \
        } \
    } while (0)

/*
 * @brief validate buffers/descriptors, then dispatch the matching tensor transform
 * @param [in] xDesc         source descriptor
 * @param [in] input         source buffer; size must cover xDesc.dataSize
 * @param [in] yDesc         destination descriptor
 * @param [out] output       destination buffer; size must cover yDesc.dataSize
 * @param [in] ySizeInBytes  caller-computed output capacity; must cover yDesc.dataSize
 * @return Status SUCCESS / FAILED
 */
HCS_API_EXPORT Status TransTensor(
    const ccTensor_t& xDesc, const ge::BaseBuffer& input,
    const ccTensor_t& yDesc, ge::BaseBuffer& output, uint32_t ySizeInBytes)
{
    if ((input.GetData() == nullptr) || (output.GetData() == nullptr)) {
        return FAILED;
    }

    // buffer size check
    if (xDesc.dataSize > input.GetSize() || yDesc.dataSize > output.GetSize()) {
        FMK_LOGE("size is not match");
        return FAILED;
    }

    CHECK(CheckTensorOverFlow(xDesc) == SUCCESS, FAILED, "xDesc is overflow!!!");
    CHECK(CheckTensorOverFlow(yDesc) == SUCCESS, FAILED, "yDesc is overflow!!!");
    CHECK(xDesc.dimCnt <= CC_DIM_MAX, FAILED, "input tensor's dimCnt is out of range");

    // recompute the minimal input byte size as dtype size * product of dims
    uint32_t xSizeInBytes = sizeof(uint16_t); // fallback; overwritten on success below
    CHECK(GetDataTypeSize(xDesc.dataType, xSizeInBytes) == SUCCESS, FAILED, "not support this dataType.");
    for (int32_t i = 0; i < xDesc.dimCnt; ++i) {
        // element count is capped at 2e9, but count * dtype size can still exceed
        // UINT32_MAX (e.g. 2e9 float elements), so guard each multiplication
        FMK_UINT32_MULCHECK(xSizeInBytes, static_cast<uint32_t>(xDesc.dim[i]));
        xSizeInBytes *= static_cast<uint32_t>(xDesc.dim[i]);
    }
    CHECK(xDesc.dataSize >= xSizeInBytes, FAILED, "input data size is error.");

    if (yDesc.dataSize > ySizeInBytes) {
        FMK_LOGE("calc size fail or output data size is too small!");
        return FAILED;
    }
    // DoTransTensor expands to a `return` on every path
    DoTransTensor(xDesc, static_cast<const void*>(input.GetData()),
        yDesc, reinterpret_cast<void*>(output.GetData()), ySizeInBytes);
}

/*
 * @ingroup fmk
 * @brief normalize a shape of 0..5 dims by padding with 1s:
 *        [] -> [1,1,1,1]; [d] -> [1,d,1,1]; [a,b] -> [1,a,b,1]; [a,b,c] -> [1,a,b,c];
 *        4D and 5D shapes pass through unchanged
 * @param [in] dim         input shape (int64 dims)
 * @param [out] dimVector  padded shape, narrowed to int32
 * @return Status SUCCESS / FAILED (more than 5 input dims)
 */
HCS_API_EXPORT Status TransferDim(const std::vector<int64_t>& dim, std::vector<int32_t>& dimVector)
{
    const size_t inputShapeSize = dim.size();
    // validate before doing any copying work
    if (inputShapeSize > 5) {
        FMK_LOGE("Cannot support inputShapeSize %zu", inputShapeSize);
        return FAILED;
    }

    // keep dims as int64 internally so each value is narrowed exactly once at the
    // end (the previous int64 -> uint32 -> int32 round trip mangled negatives)
    std::list<int64_t> newDimList(dim.begin(), dim.end());
    switch (inputShapeSize) {
        case 0:
            newDimList = {1, 1, 1, 1};
            break;
        case 1:
            newDimList.push_front(1);
            newDimList.push_back(1);
            newDimList.push_back(1);
            break;
        case 2:
            newDimList.push_front(1);
            newDimList.push_back(1);
            break;
        case 3:
            newDimList.push_front(1);
            break;
        default: // 4 or 5 dims: already in final shape
            break;
    }

    dimVector.clear();
    for (const int64_t dimNew : newDimList) {
        // NOTE(review): dims outside int32 range still narrow; upstream element-count
        // checks are presumed to keep shapes in range — confirm against callers
        dimVector.push_back(static_cast<int32_t>(dimNew));
    }
    return SUCCESS;
}
} // namespace hiai
